package com.tinyproxy.remote;

import com.tinyproxy.common.Cfg;
import com.tinyproxy.common.ExchangeHandler;
import com.tinyproxy.common.Kits;
import io.netty.bootstrap.Bootstrap;
import io.netty.buffer.ByteBuf;
import io.netty.buffer.Unpooled;
import io.netty.channel.*;
import io.netty.channel.socket.nio.NioSocketChannel;
import io.netty.handler.codec.ByteToMessageDecoder;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

import java.util.List;
import java.util.Random;
import java.util.concurrent.ThreadLocalRandom;
import java.util.concurrent.TimeUnit;

import static com.tinyproxy.common.CodecKind.DECRYPT;
import static com.tinyproxy.common.CodecKind.ENCRYPT;

/**
 * Server-side handshake decoder.
 *
 * <p>Reads a prefix (32-byte SHA-256 digest followed by a length field) and then an
 * encrypted payload whose decrypted text is {@code host \0 port \0 uuid \0 millis}.
 * When the payload validates, a success response is written, an outbound connection to
 * {@code host:port} is opened, and this decoder replaces itself with an
 * {@link ExchangeHandler} pair that relays traffic between the two channels.
 * When validation fails, a failure response is written and the inbound channel is closed.
 *
 * <p>Instances hold per-connection state and must not be shared between channels.
 */
public class RemoteDecoder extends ByteToMessageDecoder {

    private static final Logger log = LoggerFactory.getLogger(RemoteDecoder.class);

    /** Size of the SHA-256 digest that opens the handshake. */
    private static final int SHA_LEN = 32;

    /** Digest plus the length field; 36 bytes must be buffered before the prefix is parsed. */
    private static final int PREFIX_LEN = SHA_LEN + 4;

    /** Upper bound accepted for the declared payload length. */
    private static final int MAX_PAYLOAD_LEN = 1024 * 1024;

    /** Highest valid TCP port; larger values make the connect call throw. */
    private static final int MAX_PORT = 65535;

    /** Maximum accepted difference between the client timestamp and the local clock. */
    private static final long MAX_CLOCK_SKEW_MILLIS = TimeUnit.MINUTES.toMillis(20);

    private final TinyHead head = new TinyHead();

    private final byte[] staticToken;

    /** Set once the handshake has been refused; all further input is dropped. */
    private boolean rejected;

    /** Set once the handshake has been accepted; further input is left buffered. */
    private boolean accepted;

    /**
     * @param cfg configuration whose token is hashed with SHA-256 to form the static key
     */
    public RemoteDecoder(Cfg cfg) {
        this.staticToken = Kits.toSha256(Kits.asBytes(cfg.getToken()));
    }

    /**
     * Writes the handshake result to the client.
     *
     * <p>The result is a random integer: values below {@code Kits.RESP_MID} signal failure,
     * values above it signal success. On failure the channel is closed once the write
     * completes.
     *
     * @param ctx     context of the inbound channel
     * @param success whether the handshake was accepted
     * @param msg     optional message logged at WARN level; ignored when blank
     */
    private void response(ChannelHandlerContext ctx, boolean success, String msg) {
        ByteBuf buff = Unpooled.buffer(4);
        int v = ThreadLocalRandom.current().nextInt(Kits.RESP_MID);
        if (success) {
            v += Kits.RESP_MID + 100; // greater than RESP_MID means success
        }
        Kits.writeInt(buff, v);
        ChannelFuture future = ctx.writeAndFlush(buff);
        if (!success) {
            future.addListener(ChannelFutureListener.CLOSE);
        }
        if (!Kits.isBlank(msg)) {
            log.warn(msg);
        }
    }

    /**
     * Refuses the handshake: remembers the refusal, drops whatever is still buffered so the
     * decode loop cannot re-parse it, and sends the failure response (which closes the channel).
     */
    private void reject(ChannelHandlerContext ctx, ByteBuf in, String msg) {
        rejected = true;
        in.skipBytes(in.readableBytes());
        response(ctx, false, msg);
    }

    /**
     * Drives the two-step handshake. Nothing is ever added to {@code out}: bytes that arrive
     * after an accepted handshake stay in the cumulation buffer, and
     * {@link ByteToMessageDecoder} forwards them down the pipeline when this handler is
     * removed, i.e. after the {@link ExchangeHandler} has been installed.
     */
    @Override
    protected void decode(ChannelHandlerContext ctx, ByteBuf in, List<Object> out) throws Exception {
        if (rejected) {
            // The channel is about to close; never interpret trailing bytes as a new handshake.
            in.skipBytes(in.readableBytes());
            return;
        }
        if (accepted || head.handledHead()) {
            return;
        }

        if (head.getSha() == null) {
            readPrefix(ctx, in);
        } else {
            readPayload(ctx, in);
        }
    }

    /** Reads the digest and the payload length once the full prefix is buffered. */
    private void readPrefix(ChannelHandlerContext ctx, ByteBuf in) {
        if (in.readableBytes() < PREFIX_LEN) {
            return;
        }

        byte[] sha = new byte[SHA_LEN];
        in.readBytes(sha);
        int len = Kits.readInt(in);
        if (len < 0 || len > MAX_PAYLOAD_LEN) {
            reject(ctx, in, "Dangous package size, omit!");
            return;
        }
        head.setSha(sha);
        head.setPartLen(len);
    }

    /** Decrypts and validates the payload, then answers the client and starts the outbound connect. */
    private void readPayload(ChannelHandlerContext ctx, ByteBuf in) throws Exception {
        if (in.readableBytes() < head.getPartLen()) {
            return;
        }

        byte[] payload = new byte[head.getPartLen()];
        in.readBytes(payload);
        Kits.decrypt(staticToken, payload); // decrypt the payload
        // Must be captured before the payload is passed to Kits.encrypt below.
        String parts = Kits.toString(payload); // host+\0+port+\0+uuid+\0+millis
        byte[] calcSha = Kits.toSha256(payload); // SHA-256 over the decrypted (original) data
        if (!Kits.equalsMem(calcSha, head.getSha(), 0, 0, SHA_LEN)) {
            reject(ctx, in, "Auth validation error!");
            return;
        }

        // SHA-256 over the re-encrypted data plus the static token yields the session token.
        Kits.encrypt(staticToken, payload);
        head.setXtoken(Kits.toSha256(payload, staticToken));

        String[] segs = parts.split("\0");
        if (segs.length != 4 || Kits.isBlank(segs[0])) {
            reject(ctx, in, "Error header!");
            return;
        }

        head.setHost(segs[0]);
        head.setPort(Kits.toInt(segs[1], -1));
        if (head.getPort() < 1 || head.getPort() > MAX_PORT) {
            reject(ctx, in, "Error port!");
            return;
        }

        head.setMillis(Kits.toLong(segs[3], -1));
        if (Math.abs(head.getMillis() - System.currentTimeMillis()) > MAX_CLOCK_SKEW_MILLIS) {
            reject(ctx, in, "Token is timeout!");
            return;
        }

        accepted = true;
        response(ctx, true, null);
        connectRemoteServer(ctx, head);
    }

    /**
     * Opens the outbound connection on the inbound channel's event loop and, once it is
     * established, wires both channels together with {@link ExchangeHandler}s.
     *
     * @param ctx  context of the inbound (client) channel
     * @param head validated handshake data providing host, port and session token
     */
    private void connectRemoteServer(ChannelHandlerContext ctx, TinyHead head) throws Exception {
        Channel inChannel = ctx.channel();

        Bootstrap bootstrap = new Bootstrap().group(inChannel.eventLoop());
        bootstrap.channel(NioSocketChannel.class);
        bootstrap.option(ChannelOption.SO_KEEPALIVE, true);
        bootstrap.option(ChannelOption.CONNECT_TIMEOUT_MILLIS, 10000);
        // Placeholder handler; swapped for the real one when the connect succeeds.
        ChannelHandler emptyHandler = new ChannelInboundHandlerAdapter();
        bootstrap.handler(emptyHandler);

        ChannelFuture connectFuture = bootstrap.connect(head.getHost(), head.getPort());
        connectFuture.addListener(future -> {
            Channel outChannel = connectFuture.channel();
            if (!future.isSuccess()) {
                log.warn("Connect to {}:{} failed: {}",
                        head.getHost(), head.getPort(), String.valueOf(future.cause()));
                Kits.closeOnFlush(inChannel);
                return;
            }
            if (!inChannel.isActive()) {
                // The client went away while we were connecting; do not leak the upstream socket.
                outChannel.close();
                return;
            }

            // Install the relay first so bytes flushed on removal of this decoder reach it.
            ctx.pipeline().addLast(new ExchangeHandler(head.getXtoken(), outChannel, DECRYPT));
            ctx.pipeline().remove(RemoteDecoder.this);

            ChannelPipeline outPipeline = outChannel.pipeline();
            outPipeline.addLast(new ExchangeHandler(head.getXtoken(), inChannel, ENCRYPT));
            outPipeline.remove(emptyHandler);
        });
    }

    /**
     * Logs the failure and closes the channel; a connection that failed during the handshake
     * cannot be recovered into a usable state.
     */
    @Override
    public void exceptionCaught(ChannelHandlerContext ctx, Throwable cause) throws Exception {
        log.warn("--Decoder--" + cause);
        ctx.close();
    }

}
